from torch import nn


class ResBlock(nn.Module):
    """Basic residual block: two 3x3 conv + BatchNorm layers plus a shortcut.

    The block computes ``layer(x) + shortcut(x)``. The shortcut is the
    identity unless the stride or the channel count changes, in which case a
    1x1 conv + BatchNorm projection is used so the two branches have matching
    shapes and can be summed.

    Attributes:
        expansion: Factor by which the second conv's output channels exceed
            ``out_channel``; the block outputs ``out_channel * expansion``
            channels.
    """

    expansion = 1

    def __init__(self, in_channel, out_channel, stride=1):
        """Build the residual branch and the shortcut branch.

        Args:
            in_channel: Number of input channels.
            out_channel: Number of output channels of the first conv; the
                block outputs ``out_channel * expansion`` channels.
            stride: Stride of the first 3x3 conv (and of the projection
                shortcut, when one is used).
        """
        super().__init__()
        out_total = out_channel * self.expansion
        self.layer = nn.Sequential(
            nn.Conv2d(in_channel, out_channel, kernel_size=3, stride=stride, padding=1, bias=False),
            nn.BatchNorm2d(out_channel),
            nn.ReLU(inplace=True),
            # NOTE(review): this bias is redundant because BatchNorm2d follows
            # (its per-channel normalization and shift absorb any constant
            # offset), unlike the first conv which uses bias=False. It is kept
            # so parameters / state_dict keys stay compatible with existing
            # checkpoints.
            nn.Conv2d(out_channel, out_total, kernel_size=3, stride=1, padding=1, bias=True),
            nn.BatchNorm2d(out_total),
        )
        if stride != 1 or in_channel != out_total:
            # Projection shortcut: match the residual branch's channel count
            # and spatial size so the two outputs can be added.
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_channel, out_total, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(out_total),
            )
        else:
            # An empty Sequential passes its input through unchanged.
            self.shortcut = nn.Sequential()

    def forward(self, x):
        """Return ``layer(x) + shortcut(x)``.

        NOTE(review): unlike the standard ResNet BasicBlock, no ReLU is
        applied after the addition — confirm whether the caller applies it.
        """
        return self.layer(x) + self.shortcut(x)